#include "IPTables.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <new>
#include "Debug.h"
#include "IPAddress.h"

// Name of the ipset that holds the banned addresses.
// The returned string has static storage duration; callers must not free it.
const char *IPTables::IPSetName()
{
	static const char setName[] = "evil_hackers";
	return setName;
}

// Ban a list of IPs using ipset and restore.
// Set shouldAdd to true to add to ipset. 
// Set shouldAdd to false to remove from ipset.
void IPTables::BanIPsUsingRestore(std::forward_list<IPAddress*> &ipsToAddFast,bool shouldAdd)
{
	FILE *fOutput;
	const char *command = "sudo ipset restore";
	const int BUFFER_SIZE = 1024;
	char buffer[BUFFER_SIZE];
	const char *ipsetName;
	
	DEBUG;
	ipsetName = IPSetName();
	fOutput = popen(command,"w");
	if (fOutput == NULL) {
		throw "Error running command ipset restore";
	}
	for(auto loop = ipsToAddFast.begin();loop != ipsToAddFast.end();++loop) {
		if (shouldAdd) {
			snprintf(buffer,BUFFER_SIZE,"add %s %s\n",ipsetName,(*loop)->IP);
			fputs(buffer,fOutput);
			printf("Added %s\n",(*loop)->IP);
		} else {
			snprintf(buffer,BUFFER_SIZE,"del %s %s\n",ipsetName,(*loop)->IP);
			fputs(buffer,fOutput);
			printf("Removed %s\n",(*loop)->IP);
		}
	}
	pclose(fOutput);
	DEBUG;
}

// Create the ipset named IPSetName() with the "iphash" type, echoing the
// command before running it. Callers must already have established that the
// set does not exist; no existence check is made here.
void IPTables::CreateIpsetSetname()
{
	char createCommand[200];

	snprintf(createCommand,sizeof(createCommand),"sudo ipset create %s iphash",IPTables::IPSetName());
	printf("%s\n",createCommand);
	system(createCommand);
}

// Make sure the INPUT chain has a rule dropping traffic whose source address
// is in the ipset. The rule is appended (and the command echoed) only when an
// identical rule is not already present.
void IPTables::MakeSureIptablesHasIpsetRule()
{
	const int BUFFER_SIZE = 200;
	char buffer[BUFFER_SIZE];
	const char *ipSetName = IPTables::IPSetName();

	// Ask iptables directly: "-C" exits with status 0 only when an identical
	// rule exists. Scanning ReadFromIptables() output cannot work, because its
	// parser only keeps "-A INPUT -s <ip> ... -j DROP" lines, so a --match-set
	// rule is never returned and the rule would be appended on every call.
	snprintf(buffer,BUFFER_SIZE,"sudo iptables -w -C INPUT -m set --match-set %s src -j DROP 2>/dev/null",ipSetName);
	if (system(buffer) == 0) {
		return;
	}
	snprintf(buffer,BUFFER_SIZE,"sudo iptables -w -A INPUT -m set --match-set %s src -j DROP",ipSetName);
	printf("%s\n",buffer);
	system(buffer);
}

// Ban a single IP, either by adding it to the ipset (useIpset == true) or by
// appending a per-IP DROP rule to the iptables INPUT chain.
// NOTE(review): ip.IP is interpolated unquoted into a shell command; this
// assumes IPAddress only ever holds a validated address — confirm.
void IPTables::BanThisIP(IPAddress &ip,bool useIpset)
{
	const int COMMAND_SIZE = 200;
	char command[COMMAND_SIZE];
	int written;

	if (useIpset) {
		written = snprintf(command,COMMAND_SIZE,"sudo ipset add %s %s",IPTables::IPSetName(),ip.IP);
	} else {
		written = snprintf(command,COMMAND_SIZE,"sudo iptables -w -A INPUT -s %s -j DROP",ip.IP);
	}
	// snprintf always NUL-terminates, but a truncated command could ban a
	// different (shortened) address or run half a command, so refuse to run it.
	if ((written < 0) || (written >= COMMAND_SIZE)) {
		fprintf(stderr,"Command too long, not banning %s\n",ip.IP);
		return;
	}
	system(command);
}

// Clear the current list of IPs from iptables or ipset.
void IPTables::ClearCurrent(bool useIpset)
{
	Tree<IPAddress> ipList;
	if (useIpset) {
		ClearIPList(nullptr,useIpset);
	} else {
		ReadFromIptables(ipList,useIpset,false);
		ClearIPList(&(ipList.List),useIpset);
	}
	return;
}

// Remove IPs from the ipset or from iptables, echoing each shell command run.
//   ipList   - IPs to remove. With useIpset, nullptr or an empty list flushes
//              the whole set. Without useIpset, nullptr or an empty list
//              removes nothing.
//   useIpset - true to remove from the ipset, false to delete the per-IP
//              "-s <ip>/32 -j DROP" rules from the INPUT chain.
void IPTables::ClearIPList(std::forward_list<IPAddress*> *ipList,bool useIpset)
{
	const int BUFFER_SIZE = 1024;
	char buffer[BUFFER_SIZE];

	if (useIpset) {
		if ((ipList == nullptr) || (ipList->empty())) {
			// Clear the entire set in one command.
			snprintf(buffer,BUFFER_SIZE,"sudo ipset flush %s",IPTables::IPSetName());
			printf("%s\n",buffer);
			system(buffer);
		} else {
			// Remove only the listed IPs, batched through one ipset process.
			BanIPsUsingRestore(*ipList,false);
		}
		return;
	}
	// Without this guard a null list (accepted by the ipset branch) crashed here.
	if (ipList == nullptr) {
		return;
	}
	for (IPAddress *entry : *ipList) {
		// -w waits for the xtables lock, matching the other iptables calls.
		snprintf(buffer,BUFFER_SIZE,"sudo iptables -w -D INPUT -s %s/32 -j DROP",entry->IP);
		printf("%s\n",buffer);
		system(buffer);
	}
}

// Parse one output line into a newly allocated IPAddress owned by the caller.
//   fromIPset == true : the line is an "ipset list" line; the whole line,
//                       minus trailing whitespace/control characters, is the address.
//   fromIPset == false: the line is an "iptables -S" line; only
//                       "-A INPUT -s <ip>[/mask] ... -j DROP" rules yield an address.
// Returns nullptr when the line holds no address.
// Throws a const char* message if the allocation fails.
IPAddress* IPTables::Parse(const char *line,bool fromIPset)
{
	const int OFFSET = 12;	// strlen("-A INPUT -s ")
	const char *start;
	int length;
	IPAddress *parsed;

	DEBUG;
	if (fromIPset) {
		start = line;
		length = static_cast<int>(strlen(line));
		// Remove the newline and any other trailing whitespace.
		while ((length > 0) && (static_cast<unsigned char>(line[length-1]) <= ' ')) {
			length--;
		}
	} else {
		if ((strncmp(line,"-A INPUT -s ",OFFSET) != 0) || (strstr(line,"-j DROP") == nullptr)) {
			return nullptr;
		}
		// The address ends at its "/mask", at whitespace, or at end of line.
		// Looking only at the address token avoids picking up a '/' from later
		// options and avoids keeping the space before "-j" when there is no mask.
		start = line+OFFSET;
		length = static_cast<int>(strcspn(start,"/ \t\r\n"));
	}
	if (length == 0) {
		return nullptr;
	}
	DEBUG;
	// Plain new throws std::bad_alloc, which callers do not catch; nothrow
	// makes the "Out of memory" throw below (which they do catch) reachable.
	parsed = new (std::nothrow) IPAddress(start,length);
	DEBUG;
	if (parsed == nullptr) {
		throw "Out of memory in IPAddress.Parse";
	}
	return parsed;
}

// Read the currently banned IPs into tree.
//   useIpset - true:  read the members of the ipset via "ipset list". If the
//                     output has no "Members:" header, the set is assumed not
//                     to exist; it is created and its iptables DROP rule ensured.
//              false: read per-IP DROP rules from "iptables -w -S".
//   readAll  - ipset only: parse every output line instead of only the lines
//              after the "Members:" header.
// Returns false if it ran out of memory, true otherwise (including when the
// command could not be started).
bool IPTables::ReadFromIptables(Tree<IPAddress> &tree,bool useIpset,bool readAll)
{
	FILE *fInput;
	const int BUFFER_SIZE = 1024;
	char buffer[BUFFER_SIZE];
	bool outOfMemory = false;
	bool foundMembers = false;
	bool parseThisLine;
	IPAddress *ip;

	DEBUG;
	if (useIpset) {
		snprintf(buffer,BUFFER_SIZE,"sudo ipset list %s",IPSetName());
		fInput = popen(buffer,"r");
	} else {
		fInput = popen("sudo iptables -w -S","r");
	}
	if (fInput == nullptr) {
		// The command never ran, so there is no evidence the ipset is missing;
		// do not try to create it.
		return true;
	}
	DEBUG;
	while (fgets(buffer,BUFFER_SIZE,fInput) != nullptr) {
		DEBUG;
		if (outOfMemory) {
			// Keep reading (and discarding) output even though we're out of memory.
			continue;
		}
		DEBUG;
		if (useIpset) {
			// Member lines follow the "Members:" header. The header is detected
			// even when readAll is set; otherwise foundMembers stayed false and
			// the set was "created" again on every readAll call.
			parseThisLine = readAll || foundMembers;
			if (strncasecmp(buffer,"Members:",8) == 0) {
				foundMembers = true;
			}
		} else {
			parseThisLine = true;
		}
		// Parse buffer to create an IPAddress object.
		ip = nullptr;
		if (parseThisLine) {
			try
			{
				ip = Parse(buffer,useIpset);
			} catch (const char *) {
				ip = nullptr;
				outOfMemory = true;
			}
		}
		// Add the IPAddress object to the tree.
		DEBUG;
		if ((ip != nullptr) && (!outOfMemory)) {
			if (!tree.Add(&ip,IPAddress::Compare)) {
				outOfMemory = true;
			}
		}
		// Delete the IPAddress object if it wasn't added to the tree
		// (Add is expected to null ip when it takes ownership — confirm in Tree).
		if (ip != nullptr) {
			delete ip;
		}
		DEBUG;
	}
	DEBUG;
	pclose(fInput);
	if (outOfMemory) {
		return false;
	}
	if ((useIpset) && (!foundMembers)) {
		// No "Members:" header means the set doesn't exist. Create it.
		CreateIpsetSetname();
		MakeSureIptablesHasIpsetRule();
	}
	return true;
}